# -*- coding: utf-8 -*-
'''
Created on 2016年12月26日

@author: ZhuJiahui
'''


import numpy as np
from scipy import special


def delta_function_s(dir_alpha, selected_num):
    '''
    LDA delta function restricted to the `selected_num` largest parameters.

    Computes prod_k Gamma(a_k) / Gamma(sum_k a_k) over the selected entries a_k.
    The result is evaluated in log space (gammaln), so large parameters do not
    overflow Gamma to inf and yield inf/inf = nan.

    :param dir_alpha: V-dimensional Dirichlet parameter row vector
        (entries assumed positive)
    :param selected_num: number of largest entries to use; values larger than V
        select all entries, values <= 0 select none
    :return: delta function value (may still overflow to inf / underflow to 0
        when the true value lies outside the float64 range)
    '''
    sorted_alpha = np.sort(np.asarray(dir_alpha, dtype=float))
    total = len(sorted_alpha)
    # Clamp so that a request for more entries than exist takes all of them,
    # instead of producing a negative slice start that drops the smallest ones.
    count = max(0, min(selected_num, total))
    selected = sorted_alpha[total - count:]

    log_value = np.sum(special.gammaln(selected)) - special.gammaln(np.sum(selected))
    return np.exp(log_value)


def delta_function(dir_alpha):
    '''
    LDA delta function: prod_k Gamma(alpha_k) / Gamma(sum_k alpha_k).

    This is the multivariate Beta function B(alpha). It is evaluated in log
    space (gammaln), so large parameters do not overflow Gamma to inf and
    produce inf/inf = nan.

    :param dir_alpha: V-dimensional Dirichlet parameter row vector
        (entries assumed positive)
    :return: delta function value (may still overflow to inf / underflow to 0
        when the true value lies outside the float64 range)
    '''
    alpha = np.asarray(dir_alpha, dtype=float)
    log_value = np.sum(special.gammaln(alpha)) - special.gammaln(np.sum(alpha))
    return np.exp(log_value)


def dcm_margin(topic_vsm, dir_alpha):
    '''
    Compute the marginal probability of a set of topics
    (described as using the Dirichlet Compound Multinomial distribution).

    For each topic row theta the value
        prod_j theta_j ** (alpha_j - 1) / delta_function(alpha)
    is computed, and the per-topic values are multiplied together.
    Entries theta_j <= 0.0001 are skipped, i.e. they contribute a factor of 1.

    NOTE(review): the evaluated expression has the form of the Dirichlet
    density rather than the DCM (Polya) likelihood -- confirm intent.

    :param topic_vsm: topic vector space; a 1-D array (single topic) or a
        2-D array with one topic per row. Only the first len(dir_alpha)
        features are used.
    :param dir_alpha: V-dimensional Dirichlet parameter row vector
    :return: product over topics of the per-topic values (1.0 for an empty
        2-D topic set)
    :raises ValueError: if topic_vsm is not 1-D or 2-D
    '''
    alpha = np.asarray(dir_alpha, dtype=float)
    theta = np.asarray(topic_vsm, dtype=float)
    if theta.ndim not in (1, 2):
        raise ValueError('topic_vsm must be a 1-D or 2-D array')
    # Mirror the original per-feature loop, which ran over range(len(dir_alpha)).
    theta = theta[..., :len(alpha)]

    # Skipped entries are replaced by 1.0, so 1.0 ** (alpha - 1) == 1 and
    # tiny values are never raised to (possibly negative) powers.
    base = np.where(theta > 0.0001, theta, 1.0)
    kernels = np.prod(np.power(base, alpha - 1), axis=-1)

    # The normalizer depends only on alpha: compute it once, not per topic.
    normalizer = delta_function(alpha)
    return np.prod(kernels / normalizer)


def dcm_margin2(topic_vsm, dir_alpha):
    '''
    Compute the marginal probability of a set of topics
    (described as using the Dirichlet Compound Multinomial distribution).

    For each topic row theta the value
        Gamma(sum_j alpha_j) / prod_j Gamma(alpha_j) * prod_j theta_j ** (alpha_j - 1)
    is computed, and the per-topic values are multiplied together.
    Entries theta_j <= 0.0001 are skipped in the theta product (factor 1), but
    every alpha_j still contributes its Gamma(alpha_j) term.

    Everything is accumulated in log space (gammaln / log) and exponentiated
    once at the end. This avoids the intermediate inf/inf = nan and premature
    underflow that a running product hits for large parameters or many topics.

    NOTE(review): the evaluated expression has the form of the Dirichlet
    density rather than the DCM (Polya) likelihood -- confirm intent.

    :param topic_vsm: topic vector space; a 1-D array (single topic) or a
        2-D array with one topic per row. Only the first len(dir_alpha)
        features are used.
    :param dir_alpha: V-dimensional Dirichlet parameter row vector
        (entries assumed positive)
    :return: product over topics of the per-topic values (1.0 for an empty
        2-D topic set)
    :raises ValueError: if topic_vsm is not 1-D or 2-D
    '''
    alpha = np.asarray(dir_alpha, dtype=float)
    theta = np.asarray(topic_vsm, dtype=float)
    if theta.ndim not in (1, 2):
        raise ValueError('topic_vsm must be a 1-D or 2-D array')
    # Mirror the original per-feature loop, which ran over range(len(dir_alpha)).
    theta = theta[..., :len(alpha)]

    n_topics = 1 if theta.ndim == 1 else theta.shape[0]
    if n_topics == 0:
        return 1.0  # empty set of topics: empty product

    # log(Gamma(sum alpha) / prod Gamma(alpha_j)); identical for every topic.
    log_norm = special.gammaln(np.sum(alpha)) - np.sum(special.gammaln(alpha))

    # Skipped entries become 1.0, whose log is 0, so they add nothing.
    safe_theta = np.where(theta > 0.0001, theta, 1.0)
    log_kernel = np.sum((alpha - 1) * np.log(safe_theta))

    return np.exp(n_topics * log_norm + log_kernel)


def test_delta1():
    '''Demo: print delta_function for a 1000-dimensional alpha filled with 0.1.'''
    print(delta_function(np.full(1000, 0.1)))


def test_dcm_margin2():
    '''Demo: print dcm_margin2 for a flat 1000-dimensional topic (0.001 each) and alpha of 0.01.'''
    size = 1000
    print(dcm_margin2(np.full(size, 0.001), np.full(size, 0.01)))

if __name__ == '__main__':
    # When executed as a script, run the dcm_margin2 demo, which prints its value.
    test_dcm_margin2()
    
    